import os
import json
import pickle
import argparse
from PIL import Image
from tqdm import tqdm

import torch
from decord import VideoReader, cpu
from transformers import AutoImageProcessor, AutoModel, AutoProcessor, Blip2ForImageTextRetrieval

def parse_arguments(argv=None):
    """Build the command-line parser and parse the script's options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests or programmatic use).
            ``None`` (the default) keeps the usual command-line behavior.

    Returns:
        argparse.Namespace with ``dataset_path``, ``json_file``,
        ``output_dir``, ``device``, ``start_index`` and ``end_index``.
    """
    parser = argparse.ArgumentParser(description='Video feature extraction and image-text similarity calculation.')

    # --------------- Setting ---------------
    # !!! Replace the dataset path below with your local copy of Video-MME:
    #     https://github.com/MME-Benchmarks/Video-MME
    # Videos are looked up as <dataset_path>/data/<url>.mp4 by main().
    parser.add_argument('--dataset_path', type=str, default='datasets/videomme', help='Root directory of the dataset')
    # --------------- Setting ---------------

    parser.add_argument('--json_file', type=str, default='./videomme_json_file.json', help='Path to the annotation JSON file')
    parser.add_argument('--output_dir', type=str, default='./score_and_features', help='Output directory for features and scores')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
    parser.add_argument('--start_index', type=int, default=0, help='Start index for processing')
    parser.add_argument('--end_index', type=int, default=-1, help='End index; -1 means process all')
    # argv=None makes argparse fall back to sys.argv[1:], as before.
    return parser.parse_args(argv)

def load_models(device):
    """Load the DINOv2 and BLIP-2 image-text-matching models with their processors.

    Args:
        device: torch device string (e.g. 'cuda' or 'cpu') the models are moved to.

    Returns:
        dict mapping 'dinov2' and 'blip2' to (model, processor) tuples. Both
        models are switched to eval mode; BLIP-2 weights are loaded as float16.
        NOTE(review): float16 weights may be slow or unsupported for some ops
        on CPU -- confirm before running with --device cpu.
    """
    dinov2_name = 'facebook/dinov2-base'
    blip2_name = "Salesforce/blip2-itm-vit-g"

    # DINOv2 backbone, used for per-frame visual embeddings.
    dino_net = AutoModel.from_pretrained(dinov2_name).to(device)
    dino_net.eval()
    dino_prep = AutoImageProcessor.from_pretrained(dinov2_name)

    # BLIP-2 image-text matching model, loaded in half precision.
    blip_net = Blip2ForImageTextRetrieval.from_pretrained(blip2_name, torch_dtype=torch.float16).to(device)
    blip_net.eval()
    blip_prep = AutoProcessor.from_pretrained(blip2_name)

    models = {}
    models['dinov2'] = (dino_net, dino_prep)
    models['blip2'] = (blip_net, blip_prep)
    return models

def extract_dinov2_features(image, model, processor, device):
    """Embed one image with DINOv2 by averaging its last hidden state over tokens.

    Args:
        image: image accepted by the DINOv2 processor (a PIL image in this script).
        model: DINOv2 model already placed on ``device``.
        processor: the matching image processor.
        device: device the processed inputs are moved to.

    Returns:
        CPU tensor holding ``last_hidden_state.mean(dim=1)``; presumably
        (1, hidden_size) for a single image -- confirm against the model.
    """
    batch = processor(images=image, return_tensors="pt")
    batch = batch.to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
        # Mean-pool over the token axis, then bring the result back to host memory.
        pooled = hidden.mean(dim=1).cpu()
    return pooled

def compute_blip2_similarities(frames, query, blip2_model, blip2_processor, device):
    """Score how well each frame matches a text query with the BLIP-2 ITM head.

    Args:
        frames: iterable of images (PIL images in this script).
        query: text matched against every frame.
        blip2_model: BLIP-2 image-text retrieval model already on ``device``.
        blip2_processor: the matching BLIP-2 processor.
        device: device the processed inputs are moved to.

    Returns:
        list[float]: one matching probability per frame, in frame order.
    """
    # Cast floating-point inputs to the dtype the model was actually loaded
    # with instead of hard-coding float16, so a float32 model (e.g. on CPU)
    # also works. Falls back to float16, the previous fixed choice.
    input_dtype = getattr(blip2_model, 'dtype', torch.float16)
    similarities = []
    with torch.no_grad():
        for img in frames:
            inp = blip2_processor(images=img, text=query, return_tensors="pt", truncation=True).to(device, input_dtype)
            res = blip2_model(**inp, use_image_text_matching_head=True)
            # ITM logits have two classes; index 1 is taken as "match".
            match_prob = torch.nn.functional.softmax(res.logits_per_image, dim=1)[0, 1].item()
            similarities.append(match_prob)
    return similarities

def process_video(video_path, video_id, questions, models, device):
    """Sample a video at about 1 FPS, embed each frame and score it per question.

    Args:
        video_path: path of the video file, decoded with decord.
        video_id: identifier copied into the returned dict.
        questions: list of dicts with at least 'question' (str) and
            'options' (list of str) entries.
        models: dict with 'dinov2' and 'blip2' (model, processor) pairs, as
            returned by ``load_models``.
        device: device the model inputs are moved to.

    Returns:
        dict with 'video_id', 'video_path', 'frame_indices' (sampled frame
        numbers), 'num_frames' (number of sampled frames), 'questions' (one
        entry per question with its per-frame BLIP-2 similarities) and
        'features' -> {'dinov2': numpy array of the per-frame DINOv2 outputs
        concatenated along dim 0}.
    """
    # Read and sample frames at ~1 FPS by taking every int(fps)-th frame.
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    fps = vr.get_avg_fps()
    # int(fps) is 0 for clips reporting < 1 FPS (and NaN fails the comparison),
    # which would make range() raise; sample every frame in that case.
    step = int(fps) if fps >= 1 else 1
    frame_indices = list(range(0, len(vr), step))
    frames = [Image.fromarray(vr[i].asnumpy()) for i in frame_indices]

    # Extract DINOv2 features for each frame
    dinov2_model, dinov2_proc = models['dinov2']
    all_dino_features = [extract_dinov2_features(img, dinov2_model, dinov2_proc, device) for img in frames]
    if all_dino_features:
        dino_features = torch.cat(all_dino_features, dim=0).numpy()
    else:
        # torch.cat rejects an empty list (e.g. a zero-length video); keep an
        # empty (0, hidden_size) array so results can still be saved.
        dino_features = torch.empty(0, dinov2_model.config.hidden_size).numpy()

    # Compute BLIP2 similarities for each question
    blip2_model, blip2_proc = models['blip2']
    question_results = []
    for idx, q in enumerate(questions):
        # Query text is the question followed by all answer options.
        query2 = q['question'] + " " + " ".join(q["options"])
        similarities = compute_blip2_similarities(frames, query2, blip2_model, blip2_proc, device)
        question_results.append({
            'question_index': idx,
            'options': q["options"],
            'query2': query2,
            'blip2_similarities': similarities
        })

    return {
        'video_id': video_id,
        'video_path': video_path,
        'frame_indices': frame_indices,
        'num_frames': len(frames),
        'questions': question_results,
        'features': {'dinov2': dino_features}
    }

def save_results(results, output_dir, video_id):
    """Write one video's similarity scores (JSON) and feature arrays (pickle).

    Everything goes under ``<output_dir>/<video_id>/``:
    ``similarity_scores.json`` holds the metadata plus per-question scores,
    and each entry of ``results['features']`` becomes ``<name>_features.pkl``.

    Args:
        results: dict with 'video_id', 'video_path', 'frame_indices',
            'num_frames', 'questions' and 'features' (name -> picklable object).
        output_dir: root output directory; created if missing.
        video_id: name of the per-video subdirectory.
    """
    target = os.path.join(output_dir, video_id)
    os.makedirs(target, exist_ok=True)

    score_keys = ('video_id', 'video_path', 'frame_indices', 'num_frames', 'questions')
    with open(os.path.join(target, 'similarity_scores.json'), 'w', encoding='utf-8') as handle:
        summary = {key: results[key] for key in score_keys}
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    # One pickle file per named feature array.
    for feat_name, feat_array in results['features'].items():
        pkl_path = os.path.join(target, f"{feat_name}_features.pkl")
        with open(pkl_path, 'wb') as handle:
            pickle.dump(feat_array, handle)

    print(f"Saved to: {target}")

def main(args):
    """Extract features and similarity scores for a slice of the annotation file.

    Args:
        args: namespace from ``parse_arguments``; reads ``json_file``,
            ``dataset_path``, ``start_index``, ``end_index``, ``device`` and
            ``output_dir``.

    Each annotation item must provide 'url' and 'questions' (and optionally
    'video_id'). Its video is expected at ``<dataset_path>/data/<url>.mp4``.
    Items whose video file does not exist are skipped with a message.
    """
    # Load data from JSON
    with open(args.json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Compose full video paths
    all_data = [{
        **item,
        "full_video_path": os.path.join(args.dataset_path, 'data', item['url'] + '.mp4')
    } for item in data]

    # end_index == -1 means "through the end". Normalize both bounds against
    # the dataset size exactly as slicing would, so the log reports the
    # entries actually processed even for out-of-range or negative indices.
    end = None if args.end_index == -1 else args.end_index
    start_idx, end_idx, _ = slice(args.start_index, end).indices(len(all_data))
    selected_data = all_data[start_idx:end_idx]
    print(f"Processing video entries from {start_idx} to {end_idx - 1}")

    models = load_models(args.device)
    for item in tqdm(selected_data, desc="Processing videos"):
        if not os.path.exists(item['full_video_path']):
            print(f"Skipped missing file: {item['full_video_path']}")
            continue
        res = process_video(
            video_path=item['full_video_path'],
            video_id=item.get('video_id', item['url']),
            questions=item['questions'],
            models=models,
            device=args.device
        )
        save_results(res, args.output_dir, res['video_id'])

    print("All processing finished.")

if __name__ == '__main__':
    # Script entry point: parse the command line, then run the pipeline.
    cli_args = parse_arguments()
    main(cli_args)